#!/usr/bin/env python3

import argparse
import subprocess
import sys
import time

from core import Core
from core import get_node
from core import start_node
from core import stop_node
from google.cloud import tpu_v2
from provision import destroy
from provision import provision


# Common argv prefix of every gcloud TPU-VM invocation in this file; the
# subcommand ("ssh" or "scp") and its arguments are appended to it.
tpu_vm_cmd_pref = ["gcloud", "compute", "tpus", "tpu-vm"]


def _extract_short_name(node: tpu_v2.Node) -> str:
  """Return the part of the node's resource name after the last "/"."""
  _, _, short_name = node.name.rpartition("/")
  return short_name


def _parse_node_name(node: tpu_v2.Node) -> tuple[str, str, str]:
  """Split the node's resource name into (project, zone, short name).

  The name must consist of exactly six "/"-separated segments; the 2nd, 4th
  and 6th are returned. Any other segment count raises ValueError from the
  unpacking below.
  """
  segments = node.name.split("/")
  _skip_a, project, _skip_b, zone, _skip_c, name = segments
  return (project, zone, name)


def vm_run_ssh(node: tpu_v2.Node, cmd: str) -> None:
  """Run a shell command on a TPU VM through `gcloud ... tpu-vm ssh`.

  The child's stdout is inherited, so remote output is shown as it is
  produced. Its stderr is captured and only printed when the invocation
  fails.

  Args:
    node: TPU node whose resource name identifies project, zone and VM.
    cmd: Command line passed to gcloud via `--command`.

  Raises:
    RuntimeError: If the gcloud invocation exits with a non-zero status.
  """
  project, zone, name = _parse_node_name(node)
  # Monotonic clock: the reported duration must not be skewed by wall-clock
  # adjustments that happen while the command runs.
  start = time.monotonic()
  print(f"Running '{cmd}' on {name} ...\n")
  res = subprocess.run(
      [
          *tpu_vm_cmd_pref,
          "ssh",
          name,
          f"--project={project}",
          f"--zone={zone}",
          f"--command={cmd}",
      ],
      stderr=subprocess.PIPE,
  )
  if res.returncode != 0:
    # errors="replace": undecodable bytes in the captured stderr must not
    # raise UnicodeDecodeError and hide the actual failure.
    print(res.stderr.decode(errors="replace"))
    raise RuntimeError(f"Execution returned code is {res.returncode}")
  print(f"\nDone, elapsed time: {time.monotonic() - start} seconds")


def vm_scp(node: tpu_v2.Node, src_path: str, tgt_path: str) -> None:
  """Copy a local path onto a TPU VM through `gcloud ... tpu-vm scp`.

  Args:
    node: TPU node whose resource name identifies project, zone and VM.
    src_path: Local path to copy.
    tgt_path: Destination path on the VM.

  Raises:
    subprocess.CalledProcessError: If gcloud exits with a non-zero status.
  """
  project, zone, name = _parse_node_name(node)
  print(f"Copying {src_path} to {name} ...")
  destination = f"{name}:{tgt_path}"
  scp_cmd = [
      *tpu_vm_cmd_pref,
      "scp",
      src_path,
      destination,
      f"--project={project}",
      f"--zone={zone}",
  ]
  subprocess.run(scp_cmd, check=True)


class UxError(Exception):
  """Error caused by user input; meant to be shown as a plain message."""


def _parse_with_core(
    parser: argparse.ArgumentParser,
) -> tuple[argparse.Namespace, Core]:
  """Parse a subcommand's arguments and build the matching Core.

  Core's own arguments are registered on `parser` first. Parsing starts at
  `sys.argv[2:]`, i.e. after the program name and the subcommand.

  Returns:
    The parsed namespace and the Core constructed from it.
  """
  Core.add_args(parser)
  subcommand_argv = sys.argv[2:]
  parsed = parser.parse_args(subcommand_argv)
  core = Core.from_args(parsed)
  return parsed, core


class Tool:
  """Tool to provision TPUs and run FHE programs on them.

  Instantiating the class parses `sys.argv`: the first argument selects a
  subcommand, which is dispatched to the public method of the same name.
  Each subcommand method then parses the remaining arguments with its own
  parser.
  """

  def __init__(self):
    """Parse the subcommand from `sys.argv[1]` and run it.

    Exits with status 1 (after printing the help text) when the subcommand
    is not a public, callable attribute of this object.
    """
    self._default_im_sa = "tpursuit-deployer"

    parser = argparse.ArgumentParser(
        description="Tool to run executables on TPUs",
        usage="""tool <command> [<args>]
   run           Run an executable on a TPU VM
   list          List all TPU VMs

   Managing infrastructure:
   provision     Provision a TPU VM and supporting infrastructure
   destroy       Destroy a TPU VM and supporting infrastructure

   Managing TPU VMs state:
   stop          Stop a TPU VM
   start         Start a TPU VM
""",
    )
    parser.add_argument("command", help="Subcommand to run")
    # Only the subcommand is parsed here; its own arguments are left for the
    # subcommand's parser.
    args = parser.parse_args(sys.argv[1:2])
    handler = self._resolve_command(args.command)
    if handler is None:
      print("Unrecognized command")
      parser.print_help()
      sys.exit(1)
    handler()

  def _resolve_command(self, command):
    """Return the bound method implementing `command`, or None.

    Names starting with "_" and non-callable attributes are rejected so
    that internals (private helpers, dunder methods, data attributes) can
    never be invoked from the command line.
    """
    if command.startswith("_"):
      return None
    handler = getattr(self, command, None)
    return handler if callable(handler) else None

  def list(self):
    """List TPU VMs, in the configured zone or in all zones with --all."""
    parser = argparse.ArgumentParser(description="List all TPU VMs")
    parser.add_argument(
        "--all",
        action="store_true",
        default=False,
        help="List all available zones",
    )
    args, core = _parse_with_core(parser)

    self._print_list(core.list_nodes(all_zones=args.all))

  def _print_list(self, nodes) -> None:
    """Print a tab-separated table of nodes, or a notice if there are none."""
    if not nodes:
      print("No VMs found")
      return

    print("Name\tStatus\tAccelerator\tZone\tSubnet")
    for node in nodes:
      _, _, _, zone, _, name = node.name.split("/")
      # Only the last path segment of the subnetwork is shown.
      *_, sn = node.network_config.subnetwork.split("/")
      print(f"{name}\t{node.state.name}\t{node.accelerator_type}\t{zone}\t{sn}")

  def run(self):
    """Copy a program to a TPU VM, execute it there and clean up.

    The VM is started first if it is not READY. Afterwards it is stopped
    unless --keep-running was given, whether or not the run succeeded.

    Raises:
      UxError: If no suitable VM can be selected.
      RuntimeError: If the remote program exits with a non-zero status.
    """
    parser = argparse.ArgumentParser(
        description="Run an executable on a TPU VM"
    )
    parser.add_argument("--files", help="Files required to run fhe code")
    parser.add_argument("--main", help="Main file", required=True)
    parser.add_argument("--vm", type=str, help="Name of the TPU VM to run on")
    parser.add_argument(
        "--keep-running",
        action="store_true",
        default=False,
        help="Keep the VM running after the script finishes",
    )
    args, core = _parse_with_core(parser)

    node = self._select_node(core, args.vm)

    # --files is optional (None when omitted); empty entries, e.g. from a
    # trailing comma, are dropped.
    src_files_to_copy = [f for f in (args.files or "").split(",") if f]
    if args.main not in src_files_to_copy:
      src_files_to_copy.append(args.main)

    if node.state.name != "READY":  # TODO:Handle states properly
      start_node(node)
    try:
      for f in src_files_to_copy:
        # Files are copied flat: only the base name is kept on the VM.
        vm_scp(node, f, f.split("/")[-1])
      vm_run_ssh(
          node, self._build_remote_command(args.main, src_files_to_copy)
      )
    finally:
      if not args.keep_running:
        stop_node(node)

  def _select_node(self, core, vm_name):
    """Pick the node to run on from the nodes of the configured zone.

    Args:
      core: Core used to list nodes and to report the zone.
      vm_name: Short name of the requested VM, or None to use the only VM
        in the zone.

    Returns:
      The selected node.

    Raises:
      UxError: If `vm_name` is None and the zone has zero or several VMs,
        or if no VM in the zone has the requested name.
    """
    nodes = core.list_nodes(all_zones=False)
    if vm_name is None:
      if not nodes:
        raise UxError("No VMs found")
      if len(nodes) > 1:
        self._print_list(nodes)
        raise UxError("Multiple VMs found, please specify one")

      node = nodes[0]
      print(f"Using only VM in zone={core.zone}: {_extract_short_name(node)}")
      return node

    for node in nodes:
      if vm_name == _extract_short_name(node):
        return node
    self._print_list(nodes)
    raise UxError(f"VM {vm_name} not found in zone={core.zone}")

  @staticmethod
  def _build_remote_command(main, files) -> str:
    """Build the shell command that runs `main` and removes the copied files.

    Args:
      main: Local path of the main script; its base name is executed.
      files: Local paths of all copied files; their base names are removed.

    Returns:
      A single shell command line. It exits with the status of the python
      process rather than that of the trailing `rm` commands, so a failing
      program is reported as a failure.
    """
    import shlex  # Local import: only needed here.

    remote_main = shlex.quote(main.split("/")[-1])
    cmds = [f"python3 {remote_main}", "status=$?"]
    cmds.extend("rm " + shlex.quote(f.split("/")[-1]) for f in files)
    cmds.append("exit $status")
    return "; ".join(cmds)

  def stop(self):
    """Stop the TPU VM named on the command line."""
    parser = argparse.ArgumentParser(description="Stop a TPU VM")
    parser.add_argument("vm", type=str, help="Name of the TPU VM to stop")
    args, core = _parse_with_core(parser)

    stop_node(get_node(core.parent, args.vm))

  def start(self):
    """Start the TPU VM named on the command line."""
    parser = argparse.ArgumentParser(description="Start a TPU VM")
    parser.add_argument("vm", type=str, help="Name of the TPU VM to start")
    args, core = _parse_with_core(parser)

    start_node(get_node(core.parent, args.vm))

  def provision(self):
    """Provision a TPU VM and supporting infrastructure."""
    parser = argparse.ArgumentParser(
        description="Provision a TPU VM and supporting infrastructure"
    )
    parser.add_argument("vm", type=str, help="Name of the TPU VM to provision")
    parser.add_argument(
        "--runtime-version",
        type=str,
        # https://cloud.google.com/tpu/docs/runtimes
        default="v2-alpha-tpuv5-lite",
        help="Runtime version",
    )
    parser.add_argument(
        "--accelerator-type",
        type=str,
        # https://cloud.google.com/tpu/docs/system-architecture-tpu-vm
        default="v5litepod-4",
        help="Accelerator type",
    )
    parser.add_argument(
        "--keep-running",
        action="store_true",
        default=False,
        help="Keep the VM running after the script finishes",
    )
    parser.add_argument(
        "--subnet",
        type=str,
        help=(
            "Subnet to use, if none provided a new network infrastructure will"
            " be created"
        ),
    )
    args, core = _parse_with_core(parser)

    # Inside a method the bare name resolves to the module-level provision()
    # function, not to this method.
    provision(
        core=core,
        name=args.vm,
        runtime_version=args.runtime_version,
        accelerator_type=args.accelerator_type,
        subnet_url=args.subnet,
        keep_running=args.keep_running,
    )

  def destroy(self):
    """Destroys TPU VM and supporting infrastructure."""
    parser = argparse.ArgumentParser(
        description="Destroy a TPU VM and supporting infrastructure"
    )
    parser.add_argument("vm", type=str, help="Name of the TPU VM to destroy")
    parser.add_argument(
        "--keep-network",
        action="store_true",
        default=False,
        help="Keep the network infrastructure",
    )
    args, core = _parse_with_core(parser)

    # Resolves to the module-level destroy() function, not to this method.
    destroy(core, args.vm, keep_network=args.keep_network)


if __name__ == "__main__":
  try:
    Tool()
  except UxError as e:
    # User-caused errors are reported by message only, without a traceback.
    print(e)
    # sys.exit instead of the `exit` helper: the latter is injected by the
    # site module for interactive use and is not guaranteed to exist.
    sys.exit(1)
